package com.mimo.common.configuration.grpc.exception.handler;

import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.Ordered;

import com.mimo.common.logic.code.StatusCode;
import com.mimo.common.rpc.proto.BaseResponseProto.BaseResponse;

import io.grpc.stub.StreamObserver;

/**
 * Intercepts methods of {@code @GrpcService} classes in the {@code com.mimo} packages, catches any
 * {@link Exception} they throw, and replies to the client with a uniform error response.
 *
 * <p>An exception is answered with the shared error response only when the intercepted method's last
 * declared parameter is {@code StreamObserver<BaseResponse>} (type argument {@code BaseResponse} or
 * a subclass of it). In every other case the exception is logged and rethrown unchanged.
 */
@Aspect
public class GrpcExceptionHandlerAspect implements Ordered {
  private static final Logger log = LoggerFactory.getLogger(GrpcExceptionHandlerAspect.class);

  /** Reply sent to the client when an exception is handled; built once and reused. */
  private static final BaseResponse ERROR_RESPONSE = BaseResponse.newBuilder().setCode(StatusCode.Error.getCode())
      .setMsg(StatusCode.Error.getMsg()).build();

  /** Memoizes {@link #isExpectedParameter(Method)} so the reflective check runs once per method. */
  private final Map<Method, Boolean> cache = new ConcurrentHashMap<>();

  /** Matches join points inside {@code @GrpcService}-annotated types under {@code com.mimo}. */
  @Pointcut("@within(net.devh.boot.grpc.server.service.GrpcService)  && within(com.mimo..*)")
  private void grpcExceptionAsp() {
    // nothing
  }

  /**
   * Proceeds with the intercepted call and converts any {@link Exception} it throws.
   *
   * @param joinPoint the intercepted join point
   * @return the intercepted call's result, or {@code null} when an exception was handled by replying
   *     through the call's {@code StreamObserver}
   * @throws Throwable any {@link Error} raised by the call (not caught here), or the original
   *     exception when it cannot be answered through a {@code StreamObserver<BaseResponse>}
   */
  @Around(value = "grpcExceptionAsp()")
  public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
    try {
      return joinPoint.proceed();
    } catch (Exception e) {
      handle(joinPoint, e);
    }
    return null;
  }

  /**
   * Answers the client with {@link #ERROR_RESPONSE} when possible; otherwise rethrows {@code e}.
   *
   * @param joinPoint the join point whose execution failed
   * @param e the exception thrown by the intercepted call
   * @throws Throwable {@code e} itself when the join point has no usable
   *     {@code StreamObserver<BaseResponse>} argument
   */
  private void handle(ProceedingJoinPoint joinPoint, Exception e) throws Throwable {
    // Only method executions have a parameter list to inspect; any other join point kind (e.g. a
    // constructor under AspectJ weaving) must not be cast to MethodSignature, or the resulting
    // ClassCastException would hide the original exception.
    boolean isExpectedParameter = joinPoint.getSignature() instanceof MethodSignature
        && cache.computeIfAbsent(((MethodSignature) joinPoint.getSignature()).getMethod(),
            this::isExpectedParameter);

    if (isExpectedParameter) {
      log.error(e.getMessage(), e);

      Object[] args = joinPoint.getArgs();
      // Safe: isExpectedParameter verified the declared type is StreamObserver<? BaseResponse-compatible>.
      @SuppressWarnings("unchecked")
      StreamObserver<BaseResponse> streamObserver = StreamObserver.class.cast(args[args.length - 1]);
      streamObserver.onNext(ERROR_RESPONSE);
      streamObserver.onCompleted();
    } else {
      log.error("GrpcExceptionHandlerAspect : not found argument of StreamObserver<BaseResponse>. Class[{}] Method[{}]",
          joinPoint.getTarget().getClass().getSimpleName(), joinPoint.getSignature().getName());

      // Cannot reply to the client from here; propagate the original exception.
      throw e;
    }
  }

  /**
   * Checks whether the method's last declared parameter is {@code StreamObserver<T>} where {@code T}
   * is a concrete class assignable to {@link BaseResponse}.
   *
   * @param method the method to inspect
   * @return {@code true} if an error response can be sent through the last argument; {@code false}
   *     for methods without parameters, raw or wildcard-typed observers, or any other last parameter
   */
  private boolean isExpectedParameter(Method method) {
    Type[] parameterTypes = method.getGenericParameterTypes();
    // Non-RPC helper methods on the service may take no arguments; indexing would otherwise throw
    // ArrayIndexOutOfBoundsException and mask the exception being handled.
    if (parameterTypes.length == 0) {
      return false;
    }

    Type lastParameterType = parameterTypes[parameterTypes.length - 1];
    if (!(lastParameterType instanceof ParameterizedType)) {
      return false;
    }

    ParameterizedType pt = (ParameterizedType) lastParameterType;
    Type[] actualTypeArguments = pt.getActualTypeArguments();
    return pt.getRawType() == StreamObserver.class
        && actualTypeArguments.length == 1
        && actualTypeArguments[0] instanceof Class
        && BaseResponse.class.isAssignableFrom((Class<?>) actualTypeArguments[0]);
  }

  /**
   * Returns this aspect's order for Spring's {@link Ordered} contract.
   *
   * @return always {@code 0}
   */
  @Override
  public int getOrder() {
    return 0;
  }
}
